import numpy as np
import math

from ase.data.colors import jmol_colors
import ase.data as data
from itertools import chain

class Agent:
    """Properties of an agent entity: one atom's state plus its feature vector.

    Each agent refers to atom ``id`` in ``world.atoms``. It keeps the current
    and previous values of the atom's energy, gradient, gradient change
    (``dgrad``) and last step. Which entries appear in :meth:`features` is
    controlled by boolean flags read from ``world`` (``use_mass_feature``,
    ``use_grad_feature``, ...).

    NOTE(review): ``world.atoms`` is presumably an ``ase.Atoms`` object. It is
    indexed by atom id and exposes ``positions`` and ``get_cell()``; confirm
    against callers.
    """

    def __init__(self, id, name, agents, world):
        """Create the agent for atom ``id``.

        Args:
            id: Index of this agent's atom in ``world.atoms``.
            name: Agent name (stored as ``self.name``).
            agents: Collection of agents (stored as ``self.agents``).
            world: Environment object exposing ``atoms`` and the feature flags.
        """
        self.name = name
        self.id = id
        # Atom energy, current and previous value.
        self.energy = 0
        self.prev_energy = 0
        # Gradient-related 3-vectors (current and previous), zero-initialised.
        self.gradient = np.zeros(3)
        self.prev_gradient = np.zeros(3)
        self.dgrad = np.zeros(3)
        self.prev_dgrad = np.zeros(3)
        self.last_step = np.zeros(3)
        self.prev_last_step = np.zeros(3)
        # Presumably the norm of ``gradient``. It is None until assigned
        # elsewhere and must be set before features() runs with a
        # gnorm-based feature enabled.
        self.gnorm = None
        self.prev_gnorm = None
        self.agents = agents
        self.world = world

        # Display colour: the Jmol colour for this atom's atomic number.
        self.color = jmol_colors[self.world.atoms[id].number]

    def atom_to_vec(self, atom):
        """Encode the element of ``atom`` as a numeric vector.

        The first entry is the element's covalent radius. When
        ``world.use_mass_feature`` is set, three more per-element values are
        appended. They come from ``world.weight_dict``, ``world.column_dict``
        and ``world.electroneg_dict``, each keyed by chemical symbol.

        Returns:
            1-D ``np.ndarray`` of length 1, or 4 with the mass feature enabled.
        """
        res = [data.covalent_radii[atom.number]]

        if self.world.use_mass_feature:
            symbol = atom.symbol
            res.extend([self.world.weight_dict[symbol],
                        self.world.column_dict[symbol],
                        self.world.electroneg_dict[symbol]])

        return np.array(res)

    def features(self):
        """Return this agent's feature vector as a flat 1-D array.

        Segments are concatenated in this fixed order. Each segment is
        included only when its ``world`` flag is set:

        1. atom encoding from :meth:`atom_to_vec` (always included);
        2. ``log2(max(world.min_gnorm, gnorm))`` (``use_log_gnorm_feature``);
        3. ``gnorm`` (``use_gnorm_feature``);
        4. ``gradient`` (``use_grad_feature``);
        5. ``dgrad`` (``use_d_grad_feature``);
        6. ``dgrad - prev_dgrad`` (``use_dd_grad_feature``);
        7. ``last_step`` (``use_last_step_feature``);
        8. ``world.variable_step_size_coef(self)`` when
           ``world.variable_step_size`` is ``"gnorm"`` or ``"log_gnorm"``.
        """
        world = self.world
        segments = [self.atom_to_vec(self.atom())]

        if world.use_log_gnorm_feature:
            # Clamp from below before log2 so a tiny/zero gnorm stays finite
            # (presumably world.min_gnorm > 0).
            segments.append([math.log2(max(world.min_gnorm, self.gnorm))])

        if world.use_gnorm_feature:
            segments.append([self.gnorm])

        if world.use_grad_feature:
            segments.append(self.gradient)

        if world.use_d_grad_feature:
            segments.append(self.dgrad)

        if world.use_dd_grad_feature:
            segments.append(self.dgrad - self.prev_dgrad)

        if world.use_last_step_feature:
            segments.append(self.last_step)

        if world.variable_step_size in ["gnorm", "log_gnorm"]:
            segments.append([world.variable_step_size_coef(self)])

        return np.asarray(list(chain.from_iterable(segments)))

    def atom(self):
        """Return the atom in ``world.atoms`` that this agent represents."""
        return self.world.atoms[self.id]

    def distance(self, agent, offset=(0, 0, 0)):
        """Return the Euclidean distance from this agent's atom to ``agent``'s.

        ``offset`` is forwarded to :meth:`vector_to`. It holds the multiples
        of the cell vectors added to ``agent``'s position.
        """
        return np.linalg.norm(self.vector_to(agent, offset))

    def vector_to(self, agent, offset=(0, 0, 0)):
        """Return the Cartesian vector from this agent's atom to ``agent``'s.

        Args:
            agent: Target agent.
            offset: Coefficients of the cell vectors added to ``agent``'s
                position (``offset @ cell``). These are presumably integer
                indices of a periodic image; confirm with callers. The
                default applies no shift.

        Returns:
            ``positions[agent.id] + offset @ cell - positions[self.id]``.
        """
        atoms = self.world.atoms
        # Convert both operands first. ``@`` then works for list/tuple offsets
        # and for array-like cells that are not ndarray instances; otherwise
        # Python may find no __matmul__/__rmatmul__ and raise TypeError.
        shift = np.asarray(offset) @ np.asarray(atoms.get_cell())
        return atoms.positions[agent.id] + shift - atoms.positions[self.id]
